# Copyright 2018- The Pixie Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import yaml
import os
import re


def grab_address(fname):
    """Return the address portion of a LEGO-style certificate filename.

    Filenames ending in '.crt' or '.key' map to the filename with that
    suffix removed (e.g. 'clusters.withpixie.dev.key' ->
    'clusters.withpixie.dev'). Filenames ending in '.issuer.crt' or
    '.json', and any other filename, map to '' so callers can skip them.

    Args:
        fname: a bare filename (no directory component).

    Returns:
        The address string, or '' when the file is not a wanted cert/key.
    """
    desired_extensions = ('.crt', '.key')
    unwanted_extensions = ('.json', '.issuer.crt')

    # Check the unwanted suffixes first: '.issuer.crt' also ends in '.crt'.
    if fname.endswith(unwanted_extensions):
        return ''

    # Match extensions as suffixes only. A substring test plus str.replace
    # would corrupt hostnames that happen to contain '.key' or '.crt'
    # (e.g. 'clusters.keystone.dev.key').
    for e in desired_extensions:
        if fname.endswith(e):
            return fname[:-len(e)]

    return ''


def make_address_file_map(directory):
    """Group the entries of *directory* by the address ``grab_address`` derives.

    Entries for which ``grab_address`` returns '' are ignored.

    Returns:
        dict mapping address -> list of filenames (in ``os.listdir`` order).
    """
    grouped = {}
    for entry in os.listdir(directory):
        addr = grab_address(entry)
        if addr != '':
            grouped.setdefault(addr, []).append(entry)
    return grouped


def make_file_content_dictionary(parent_directory, parent_key, files):
    """Read each file in *files* and key its whitespace-stripped text.

    Each key is the filename with every occurrence of '<parent_key>.'
    removed, e.g. 'addr.crt' -> 'crt' when parent_key is 'addr'.

    Args:
        parent_directory: directory the filenames are relative to.
        parent_key: the address prefix to drop from each filename.
        files: iterable of filenames to read.
    """
    prefix = "{}.".format(parent_key)

    def _read_stripped(name):
        # Read one file relative to parent_directory and trim surrounding whitespace.
        with open(os.path.join(parent_directory, name)) as handle:
            return handle.read().strip()

    return {name.replace(prefix, ""): _read_stripped(name) for name in files}


def format_file_dictionary(parent_directory, address_to_files_map):
    """Replace each address's filename list with a dict of its file contents.

    Returns:
        dict mapping address -> result of ``make_file_content_dictionary``
        for that address's files.
    """
    return {
        addr: make_file_content_dictionary(parent_directory, addr, names)
        for addr, names in address_to_files_map.items()
    }


def get_namespace(address):
    """Extract the cluster namespace substring from *address*.

    The namespace is the single match of the regex
    'clusters.*dev|clusters.*ai' within the address, e.g.
    '_.clusters.withpixie.dev' -> 'clusters.withpixie.dev'.

    Raises:
        ValueError: if the pattern matches zero or more than one time.
    """
    matches = re.findall('clusters.*dev|clusters.*ai', address)
    # Explicit raise instead of assert: asserts are stripped under `python -O`,
    # which would turn a bad address into an IndexError or a silent wrong pick.
    if len(matches) != 1:
        raise ValueError("num matches = {}; address = {}".format(
            len(matches), address))
    return matches[0]


def split_by_cluster(address_to_file_content_map, cluster_to_fname_mapping):
    """Bucket per-address file contents by output filename.

    Each address's namespace (via ``get_namespace``) is looked up in
    *cluster_to_fname_mapping*. Addresses whose namespace has no entry
    are dropped.

    Returns:
        dict mapping output filename -> {address: file-contents dict}.
    """
    grouped = {}
    for addr, contents in address_to_file_content_map.items():
        namespace = get_namespace(addr)
        # Hack: namespaces without a configured output file are skipped.
        if namespace in cluster_to_fname_mapping:
            out_name = cluster_to_fname_mapping[namespace]
            grouped.setdefault(out_name, {})[addr] = contents
    return grouped


def parentdir(path):
    """Return the parent directory of *path*, normalized with os.path.normpath.

    The '..' is collapsed lexically, without consulting the filesystem.
    """
    up_one = os.path.join(path, os.pardir)
    return os.path.normpath(up_one)


def write_to_files(file_name_to_file_contents_dict):
    """Serialize each value with ``yaml.dump`` into the file named by its key.

    The parent directory of every output file is printed and created if it
    is missing. Existing output files are overwritten.

    Args:
        file_name_to_file_contents_dict: maps an output path to the object
            to dump (block style, ``default_flow_style=False``).
    """
    for fname, fcontents in file_name_to_file_contents_dict.items():
        file_parent = parentdir(fname)
        print(file_parent)
        # exist_ok=True avoids the race between a separate os.path.exists()
        # check and os.makedirs() if the directory appears in between.
        os.makedirs(file_parent, exist_ok=True)
        with open(fname, 'w') as yaml_file:
            yaml.dump(fcontents, yaml_file, default_flow_style=False)


def parse_args():
    """Parse the three positional CLI arguments: certsdir, address, outfile."""
    parser = argparse.ArgumentParser(
        description='Assemble the yaml for CA certs generated by LEGO.')
    positionals = (
        ('certsdir', 'The directory that contains the certs.'),
        ('address', "The address of the certs to grab."),
        ('outfile', 'The output file.'),
    )
    for dest, help_text in positionals:
        parser.add_argument(dest, help=help_text)
    return parser.parse_args()


def main():
    """Collect the cert/key files for one address and write them as YAML.

    Reads <certsdir>, keeps the addresses whose namespace equals the
    <address> argument, and writes their file contents to <outfile>.
    """

    def literal_multiline_str(dumper, text):
        # Emit multi-line strings (e.g. certificate bodies) in YAML literal
        # '|' style for readability; single-line strings use the default style.
        str_tag = 'tag:yaml.org,2002:str'
        if len(text.splitlines()) > 1:
            return dumper.represent_scalar(str_tag, text, style='|')
        return dumper.represent_scalar(str_tag, text)

    yaml.add_representer(str, literal_multiline_str)

    args = parse_args()
    grouped_files = make_address_file_map(args.certsdir)
    contents_by_address = format_file_dictionary(args.certsdir, grouped_files)
    outputs = split_by_cluster(
        contents_by_address, {args.address: args.outfile})
    write_to_files(outputs)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
